{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from pytorch_tabular.utils import load_covertype_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "data, cat_col_names, num_col_names, target_col = load_covertype_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Importing the Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.models import (\n",
    "    CategoryEmbeddingModelConfig,\n",
    ")\n",
    "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig\n",
    "from pytorch_tabular.models.common.heads import LinearHeadConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = train_test_split(data, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Bagged Prediction\n",
    "\n",
    "This is when we train different models on different folds of the data and finally ensemble them together to get the final prediction. This is a very powerful technique and can be used to increase the accuracy of the model and is something that is very frequently done in Kaggle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "Collapsed": "false",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_config = DataConfig(\n",
    "    target=[\n",
    "        target_col\n",
    "    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    ")\n",
    "trainer_config = TrainerConfig(\n",
    "    batch_size=1024,\n",
    "    max_epochs=100,\n",
    "    early_stopping=\"valid_loss\",  # Monitor valid_loss for early stopping\n",
    "    early_stopping_mode=\"min\",  # Set the mode as min because for val_loss, lower is better\n",
    "    early_stopping_patience=5,  # No. of epochs of degradation training will wait before terminating\n",
    "    checkpoints=\"valid_loss\",  # Save best checkpoint monitoring val_loss\n",
    "    load_best=True,  # After training, load the best checkpoint\n",
    "    progress_bar=\"none\",  # Turning off Progress bar\n",
    "    trainer_kwargs=dict(enable_model_summary=False),  # Turning off model summary\n",
    ")\n",
    "optimizer_config = OptimizerConfig()\n",
    "\n",
    "head_config = LinearHeadConfig(\n",
    "    layers=\"\",\n",
    "    dropout=0.1,\n",
    "    initialization=(  # No additional layer in head, just a mapping layer to output_dim\n",
    "        \"kaiming\"\n",
    "    ),\n",
    ").__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)\n",
    "\n",
    "model_config = CategoryEmbeddingModelConfig(\n",
    "    task=\"classification\",\n",
    "    layers=\"1024-512-512\",  # Number of nodes in each layer\n",
    "    activation=\"LeakyReLU\",  # Activation between each layers\n",
    "    learning_rate=1e-3,\n",
    "    head=\"LinearHead\",  # Linear Head\n",
    "    head_config=head_config,  # Linear Head Config\n",
    ")\n",
    "\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    verbose=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    }
   ],
   "source": [
    "tabular_model.fit(train=train)\n",
    "pred_df = tabular_model.predict(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use the high level method `bagging_predict` in `TabularModel`'\n",
    "\n",
    "The arguments are as follows:\n",
    "- `cv` can either be an integer or a `KFold` object. If it is an integer, it will be treated as the number of folds in a KFold. For classification problems, it will be a StratifiedKFold. if its a `KFold` object, it will be used as is.\n",
    "- `train` is the training dataset. \n",
    "- `test` is the test dataset.\n",
    "- `aggregate` is how we ensemble the predictions. It can be `mean`, `median`, `min`, `max`, and `hard_voting`. `hard_voting` is only valid for classification. We can also pass in a custom function that takes takes in a list of 3D arrays (num_samples, num_cv, num_targets) and returns a 2D array of final probabilities (num_samples, num_targets)\n",
    "- `weights` is used for aggregating the predictions from each fold. If None, will use equal weights. This is only used when `aggregate` is \"mean\"\n",
    "- `return_raw_predictions` If True, will return the raw predictions from each fold. Defaults to False.\n",
    "\n",
    "For the entire list of arguments, please refer to the docstring."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:04:38</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">344</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2276</span><span style=\"font-weight: bold\">}</span> - INFO - Running Fold <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>                           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m31\u001b[0m \u001b[1;92m18:04:38\u001b[0m,\u001b[1;36m344\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m2276\u001b[0m\u001b[1m}\u001b[0m - INFO - Running Fold \u001b[1;36m1\u001b[0m/\u001b[1;36m3\u001b[0m                           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:08:17</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">581</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2307</span><span style=\"font-weight: bold\">}</span> - INFO - Fold <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> prediction done                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m31\u001b[0m \u001b[1;92m18:08:17\u001b[0m,\u001b[1;36m581\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m2307\u001b[0m\u001b[1m}\u001b[0m - INFO - Fold \u001b[1;36m1\u001b[0m/\u001b[1;36m3\u001b[0m prediction done                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:08:17</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">590</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2276</span><span style=\"font-weight: bold\">}</span> - INFO - Running Fold <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>                           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m31\u001b[0m \u001b[1;92m18:08:17\u001b[0m,\u001b[1;36m590\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m2276\u001b[0m\u001b[1m}\u001b[0m - INFO - Running Fold \u001b[1;36m2\u001b[0m/\u001b[1;36m3\u001b[0m                           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:13:21</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">030</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2307</span><span style=\"font-weight: bold\">}</span> - INFO - Fold <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> prediction done                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m31\u001b[0m \u001b[1;92m18:13:21\u001b[0m,\u001b[1;36m030\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m2307\u001b[0m\u001b[1m}\u001b[0m - INFO - Fold \u001b[1;36m2\u001b[0m/\u001b[1;36m3\u001b[0m prediction done                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:13:21</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">039</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2276</span><span style=\"font-weight: bold\">}</span> - INFO - Running Fold <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>                           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m31\u001b[0m \u001b[1;92m18:13:21\u001b[0m,\u001b[1;36m039\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m2276\u001b[0m\u001b[1m}\u001b[0m - INFO - Running Fold \u001b[1;36m3\u001b[0m/\u001b[1;36m3\u001b[0m                           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:16:38</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">902</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2307</span><span style=\"font-weight: bold\">}</span> - INFO - Fold <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> prediction done                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m31\u001b[0m \u001b[1;92m18:16:38\u001b[0m,\u001b[1;36m902\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m2307\u001b[0m\u001b[1m}\u001b[0m - INFO - Fold \u001b[1;36m3\u001b[0m/\u001b[1;36m3\u001b[0m prediction done                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# cross validation loop usnig sklearn\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "\n",
    "kf = KFold(n_splits=5, shuffle=True, random_state=42)\n",
    "\n",
    "\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    bagged_pred_df = tabular_model.bagging_predict(\n",
    "        cv=3, train=train, test=test, aggregate=\"mean\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original Accuracy: 0.9506860443502028 | Bagged Accuracy: 0.9533778992516506\n"
     ]
    }
   ],
   "source": [
    "# Calculating metrics\n",
    "orig_acc = accuracy_score(test[target_col].values, pred_df[\"prediction\"].values)\n",
    "bagged_acc = accuracy_score(test[target_col].values, bagged_pred_df[\"prediction\"].values)\n",
    "print(f\"Original Accuracy: {orig_acc} | Bagged Accuracy: {bagged_acc}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "ad8d5d2789703c7b1c2f7bfaada1cbd3aa0ac53e2e4e1cae5da195f5520da229"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
